import argparse
import torch
import numpy as np
import gym
import os
import random
from tqdm import tqdm
from torch.optim import Adam
import matplotlib.pyplot as plt
from noda.noda import NODA, NODANoPartial
from simulators.noda import AE
from sac.sac import ReplayBuffer
import pdb


def compute_loss_model(model, data):
    """Return the per-sample model loss for one batch of transitions.

    The loss is the average of two squared-error terms, each reduced with a
    mean over dim 1: the next-observation prediction error and the
    observation reconstruction error. The reward prediction returned by the
    model is not part of the loss.

    Args:
        model: Callable invoked as ``model(obs, act)`` returning the triple
            ``(next_obs_pred, reward_pred, obs_recon)``.
        data: Mapping with the keys ``'obs'``, ``'act'``, ``'rew'``,
            ``'obs2'`` and ``'done'``.
            # assumes obs-like tensors are (batch, features) -- TODO confirm

    Returns:
        A tensor with one loss value per sample (no reduction over the batch).
    """
    batch_keys = ('obs', 'act', 'rew', 'obs2', 'done')
    # All five entries are looked up; reward and done are currently unused.
    obs, act, _reward, next_obs, _done = (data[key] for key in batch_keys)

    next_obs_pred, _reward_pred, obs_recon = model(obs, act)

    prediction_error = (next_obs_pred - next_obs).pow(2).mean(dim=1)
    reconstruction_error = (obs_recon - obs).pow(2).mean(dim=1)
    return 0.5 * (prediction_error + reconstruction_error)


def get_buffer(env, steps, device, max_ep_len=1000):
    """Collect ``steps`` one-step transitions using sampled actions.

    Actions come from ``env.action_space.sample()``. An episode is restarted
    when the environment reports done or when it reaches ``max_ep_len``
    steps; in the latter case the stored done flag is forced to ``False``.

    Args:
        env: Environment whose ``reset()`` returns an observation and whose
            ``step(action)`` returns a 4-tuple ``(obs, reward, done, info)``.
        steps: Number of transitions to collect (also the buffer size).
        device: Device forwarded to ``ReplayBuffer``.
        max_ep_len: Episode length at which the episode is cut off.

    Returns:
        The filled ``ReplayBuffer``.
    """
    replay = ReplayBuffer(obs_dim=env.observation_space.shape,
                          act_dim=env.action_space.shape[0],
                          size=steps,
                          device=device)
    obs = env.reset()
    episode_return, episode_steps = 0, 0
    with tqdm(total=steps, desc='Generating data') as progress:
        for _ in range(steps):
            action = env.action_space.sample()
            next_obs, reward, done, _info = env.step(action)
            episode_return += reward
            episode_steps += 1

            # Hitting the length cap is a cut-off, not a real terminal state.
            hit_limit = episode_steps == max_ep_len
            if hit_limit:
                done = False
            replay.store(obs, action, reward, next_obs, done)
            obs = next_obs

            if done or hit_limit:
                obs = env.reset()
                episode_return, episode_steps = 0, 0
            progress.update()
    return replay


def get_buffer_two_step(env, steps, device, max_ep_len=1000):
    """Collect ``steps`` overlapping two-step transitions using sampled actions.

    Each stored entry spans two consecutive environment steps: the
    observation before the first step, the mean of the two actions, the sum
    of the two rewards, the observation after the second step, and the done
    flag of the second step. Consecutive entries overlap by one environment
    step. The done flag is forced to ``False`` when the episode is cut off at
    ``max_ep_len``.

    Args:
        env: Environment whose ``reset()`` returns an observation and whose
            ``step(action)`` returns a 4-tuple ``(obs, reward, done, info)``.
        steps: Number of two-step transitions to collect (also the buffer size).
        device: Device forwarded to ``ReplayBuffer``.
        max_ep_len: Episode length at which the episode is cut off.

    Returns:
        The filled ``ReplayBuffer``.

    Raises:
        AssertionError: If the very first step of an episode reports done.
    """
    replay = ReplayBuffer(obs_dim=env.observation_space.shape,
                          act_dim=env.action_space.shape[0],
                          size=steps,
                          device=device)

    def begin_episode():
        # Reset and take the first step so a full two-step window is available.
        initial_obs = env.reset()
        initial_act = env.action_space.sample()
        following_obs, initial_rew, initial_done, _info = env.step(initial_act)
        assert not initial_done
        return initial_obs, initial_act, following_obs, initial_rew

    start_obs, first_act, mid_obs, first_rew = begin_episode()
    episode_return, episode_steps = 0, 1
    with tqdm(total=steps, desc='Generating data') as progress:
        for _ in range(steps):
            second_act = env.action_space.sample()
            end_obs, second_rew, done, _info = env.step(second_act)
            window_reward = first_rew + second_rew
            episode_return += window_reward
            episode_steps += 1

            # Hitting the length cap is a cut-off, not a real terminal state.
            hit_limit = episode_steps == max_ep_len
            if hit_limit:
                done = False
            replay.store(start_obs, (first_act + second_act) / 2, window_reward, end_obs, done)

            if done or hit_limit:
                start_obs, first_act, mid_obs, first_rew = begin_episode()
                episode_return, episode_steps = 0, 1
            else:
                # Slide the two-step window forward by one environment step.
                start_obs, mid_obs = mid_obs, end_obs
                first_rew, first_act = second_rew, second_act
            progress.update()
    return replay


def train(args, model, train_buffer, test_buffer=None, model_steps=None):
    """Optimise ``model`` with Adam and record full-buffer losses after every step.

    Each step draws ``args.batch_size`` indices from a random permutation of
    ``range(args.env_steps)``, takes one gradient step on the mean loss, and
    then evaluates the per-sample loss on ``args.env_steps`` training
    samples (and test samples, when ``test_buffer`` is given) without
    gradients.

    Args:
        args: Namespace providing ``lr``, ``env_steps``, ``batch_size`` and
            ``model_steps``.
        model: Model to train; passed to ``compute_loss_model``.
        train_buffer: Buffer exposing ``get_batch(indices)``.
        test_buffer: Optional buffer exposing ``get_batch(indices)``.
        model_steps: Number of optimisation steps; defaults to
            ``args.model_steps`` when ``None``.

    Returns:
        Tuple ``(model, train_losses, test_losses)`` where the loss entries
        are NumPy arrays with one row of per-sample losses per step;
        ``test_losses`` is empty when no ``test_buffer`` is given.
    """
    total_steps = args.model_steps if model_steps is None else model_steps
    optimizer = Adam(model.parameters(), lr=args.lr)
    train_history, test_history = [], []

    with tqdm(total=total_steps) as progress:
        for _ in range(total_steps):
            # One gradient step on a random mini-batch.
            batch_ixs = torch.randperm(args.env_steps)[:args.batch_size]
            batch_loss = compute_loss_model(model, train_buffer.get_batch(batch_ixs))
            optimizer.zero_grad()
            batch_loss.mean().backward()
            optimizer.step()

            # Evaluation over the whole buffer(s), outside the autograd graph.
            with torch.no_grad():
                train_ixs = torch.randperm(args.env_steps)
                train_eval = compute_loss_model(model, train_buffer.get_batch(train_ixs))
                train_history.append(train_eval.cpu().numpy())
                postfix = {}
                if test_buffer is not None:
                    test_ixs = torch.randperm(args.env_steps)
                    test_eval = compute_loss_model(model, test_buffer.get_batch(test_ixs))
                    test_history.append(test_eval.cpu().numpy())
                    postfix['test_loss'] = '{:.9f}'.format(test_eval.mean().item())
                postfix['train_loss'] = '{:.9f}'.format(train_eval.mean().item())
                progress.set_postfix(**postfix)
            progress.update()

    return model, np.array(train_history), np.array(test_history)


def transfer_plot(args, target_path, labels=None):
    """Plot training and testing loss curves side by side and save them as a PDF.

    For every model stored in the archive, the per-step mean over the last
    axis is drawn as a line with a shaded band of one standard deviation
    around it. Training curves go on the left panel, testing curves on the
    right; both panels share the y axis.

    Args:
        args: Namespace providing ``save_dir``, ``env``, ``hid_noda_ae``,
            ``hid_noda_ode``, ``env_steps`` and ``model_steps``.
        target_path: Path of the ``.npz`` archive containing the arrays
            ``train_loss_models`` and ``test_loss_models``.
        labels: Optional pair ``[train_labels, test_labels]`` of legend
            labels, one label per model. Curves and labels are paired in
            order; surplus entries on either side are ignored.
    """
    if labels is None:
        labels = [['Original (training)', 'Transferred (training)'],
                  ['Original (testing)', 'Transferred (testing)']]

    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it once the arrays have been read.
    # NOTE: allow_pickle=True can execute arbitrary code when loading an
    # untrusted file; target_path is expected to be written by this script.
    with np.load(target_path, allow_pickle=True) as results:
        train_loss_models = results['train_loss_models']
        test_loss_models = results['test_loss_models']

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5), sharey='row')
    step = np.arange(1, args.model_steps + 1)
    panels = (
        (axes[0], train_loss_models, labels[0]),
        (axes[1], test_loss_models, labels[1]),
    )
    for axis, loss_models, panel_labels in panels:
        for losses, label in zip(loss_models, panel_labels):
            mean = losses.mean(axis=-1)
            std = losses.std(axis=-1)
            axis.plot(step, mean, label=label)
            axis.fill_between(step, mean - std, mean + std, alpha=0.3)
        axis.grid(True)
        axis.set_xlabel('Step')
        axis.legend(loc='best')
    # The y axis is shared, so only the left panel carries the label.
    axes[0].set_ylabel('Loss')

    file_name = '_'.join(['transfer', args.env, str(args.hid_noda_ae), str(args.hid_noda_ode),
                          str(args.env_steps), str(args.model_steps)]) + '.pdf'
    fig.savefig(args.save_dir + '/' + file_name)
    plt.close(fig)


def transfer_plot_testing(args, target_path, labels=None):
    """Plot only the testing loss curves on a single panel and save them as a PDF.

    For every labelled model, the per-step mean over the last axis is drawn
    as a line with a shaded band of one standard deviation around it.

    Args:
        args: Namespace providing ``save_dir``, ``env``, ``hid_noda_ae``,
            ``hid_noda_ode``, ``env_steps`` and ``model_steps``.
        target_path: Path of the ``.npz`` archive containing the array
            ``test_loss_models``.
        labels: Optional list of legend labels, one per model. Curves and
            labels are paired in order; surplus entries on either side are
            ignored.
    """
    if labels is None:
        labels = ['NODA-Original', 'NODA-Transferred']

    # np.load returns an NpzFile that keeps the archive open; the context
    # manager closes it once the array has been read.
    # NOTE: allow_pickle=True can execute arbitrary code when loading an
    # untrusted file; target_path is expected to be written by this script.
    with np.load(target_path, allow_pickle=True) as results:
        test_loss_models = results['test_loss_models']

    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(7, 5))
    step = np.arange(1, args.model_steps + 1)
    for losses, label in zip(test_loss_models, labels):
        mean = losses.mean(axis=-1)
        std = losses.std(axis=-1)
        ax.plot(step, mean, label=label)
        ax.fill_between(step, mean - std, mean + std, alpha=0.3)
    ax.set_xlabel('Steps')
    ax.set_ylabel('Testing Loss')
    ax.legend(loc='best')
    ax.set_title(args.env)
    ax.grid(True)

    file_name = '_'.join(['transfer', args.env, str(args.hid_noda_ae), str(args.hid_noda_ode),
                          str(args.env_steps), str(args.model_steps)]) + '_testing.pdf'
    fig.savefig(args.save_dir + '/' + file_name)
    plt.close(fig)


def _train_and_save_transfer_results(args, target_path):
    """Run the transfer experiment and write both loss histories to ``target_path``.

    Two identically configured ``NODANoPartial`` models are built. The second
    one is first trained on one-step data for ``args.model_steps // 10``
    steps and has its ``integration_time`` doubled; then both models are
    trained and evaluated on two-step data.
    """
    env = gym.make(args.env)

    def build_model():
        return NODANoPartial(env.observation_space, env.action_space,
                             latent_dim=args.lat_noda,
                             hidden_dim_ode=args.hid_noda_ode,
                             hidden_dim_ae=args.hid_noda_ae).to(args.device)

    # Index 0: trained from scratch; index 1: pretrained, then transferred.
    models = [build_model(), build_model()]

    pretrain_buffer = get_buffer(env, args.env_steps, args.device)
    train_buffer = get_buffer_two_step(env, args.env_steps, args.device)
    test_buffer = get_buffer_two_step(env, args.env_steps, args.device)

    models[1] = train(args, models[1], pretrain_buffer, model_steps=args.model_steps // 10)[0]

    train_histories, test_histories = [], []
    for index, model in enumerate(models):
        if index == 1:
            # The pretrained model saw one-step data; double its integration
            # time before training on two-step transitions.
            model.integration_time *= 2
        _, train_loss, test_loss = train(args, model, train_buffer, test_buffer)
        train_histories.append(train_loss)
        test_histories.append(test_loss)

    np.savez(target_path,
             train_loss_models=np.array(train_histories),
             test_loss_models=np.array(test_histories))


def transfer_main():
    """Command-line entry point: train (or reuse cached results) and plot.

    Parses the command line, seeds the Python, NumPy and PyTorch random
    generators, and makes sure the output directory exists. If a results
    archive for the chosen configuration already exists and ``--retrain`` is
    not given, only the plots are regenerated; otherwise the experiment is
    run first and its results are saved.
    """
    parser = argparse.ArgumentParser()
    # (flags, keyword options), registered in this order.
    cli_options = (
        (('--env',), dict(type=str, default='Ant-v3')),
        (('--seed', '-s'), dict(type=int, default=0)),
        (('--exp_name',), dict(type=str, default='noda')),
        (('--lat-noda',), dict(type=int, default=40)),
        (('--hid-noda-ae',), dict(type=int, default=256)),
        (('--hid-noda-ode',), dict(type=int, default=64)),
        (('--env-steps',), dict(type=int, default=20000)),
        (('--batch-size',), dict(type=int, default=256)),
        (('--model-steps',), dict(type=int, default=1000)),
        (('--lr',), dict(type=float, default=0.001)),
        (('--retrain',), dict(action='store_true', default=False)),
        (('--save-dir',), dict(default='results/transfer', type=str)),
    )
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)
    args = parser.parse_args()
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Seed every random generator in use and request deterministic cuDNN.
    seed = args.seed
    random.seed(seed)
    torch.manual_seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    if not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir)

    archive_name = '_'.join(['transfer_results', args.env, str(args.hid_noda_ae), str(args.hid_noda_ode),
                             str(args.env_steps), str(args.model_steps)]) + '.npz'
    target_path = args.save_dir + '/' + archive_name

    # Reuse cached results unless they are missing or a retrain is requested.
    if args.retrain or not os.path.isfile(target_path):
        _train_and_save_transfer_results(args, target_path)

    transfer_plot(args, target_path)
    transfer_plot_testing(args, target_path)


if __name__ == '__main__':
    # Global figure styling, set once before any figure is created.
    plt.rcParams.update({
        'font.sans-serif': ['Times New Roman'],
        'figure.autolayout': True,
        'font.size': 23,
    })
    # plt.rcParams['xtick.direction'] = 'in'
    # plt.rcParams['ytick.direction'] = 'in'
    transfer_main()
